from __future__ import print_function

import os
import sys
import argparse
import time
import math

import torch
import torch.backends.cudnn as cudnn
from torchvision import transforms, datasets

from util import TwoCropTransform, AverageMeter
from util import adjust_learning_rate, warmup_learning_rate
from util import set_optimizer, save_model, plot_singular_values, get_entropy_energy_based_rank
from networks.resnet_big import SupConResNet, SimSiam, BarlowTwinsModel, DirectDLRModel
from losses import SupConLoss, SimSiamLoss, BarlowTwinsLoss, DirectDLRLoss
import wilds
from wilds.common.data_loaders import get_ssl_train_loader
from wilds.common.grouper import CombinatorialGrouper
from main_linear import main as linear_eval
from main_linear import parse_option as linear_parse_option
import numpy as np
import random
import wandb
import torch.nn.functional as F

from wilds.datasets.boolean_dataset import SpuriousBooleanSampling, SpuriousBooleanFinite, ContrastiveBooleanDataset
from util import generate_parity_func, RepresentationTracker, estimate_coeff, spur_attr, core_attr, avg_core_coeff, prune_projection_head
from util import make_spur_mask, split_core_spur, augmentation_gap

# apex is an optional dependency: in this file `apex` is only referenced by
# set_model() when --syncBN is passed, so a missing install is tolerated here.
# NOTE(review): `amp` and `optimizers` are not referenced anywhere in the
# visible source.
try:
    import apex
    from apex import amp, optimizers
except ImportError:
    pass


def parse_option():
    """Parse command-line options and derive run-specific settings.

    On top of the raw argparse values, the returned namespace carries:

    * ``data_folder``  - defaults to ``'./datasets/'`` when not given,
    * ``lr_decay_epochs`` - converted from ``'a,b,c'`` to ``[a, b, c]`` (ints),
    * ``model_path`` / ``model_name`` / ``save_folder`` - checkpoint location,
    * ``warm`` - forced to ``True`` when ``batch_size > 256``,
    * ``warmup_from`` / ``warm_epochs`` / ``warmup_to`` - only set when
      warm-up is active.

    Side effect: creates ``opt.save_folder`` on disk if it does not exist.

    Returns:
        argparse.Namespace: the populated options.

    Raises:
        SystemExit: on invalid command-line input, including ``--dataset path``
            without ``--data_folder``, ``--mean`` and ``--std``.
    """
    parser = argparse.ArgumentParser('argument for training')

    parser.add_argument('--print_freq', type=int, default=10,
                        help='print frequency')
    parser.add_argument('--save_freq', type=int, default=50,
                        help='save frequency')
    parser.add_argument('--batch_size', type=int, default=256,
                        help='batch_size')
    parser.add_argument('--num_workers', type=int, default=32,
                        help='num of workers to use')
    parser.add_argument('--epochs', type=int, default=1000,
                        help='number of training epochs')

    # optimization
    parser.add_argument('--learning_rate', type=float, default=0.01,
                        help='learning rate')
    parser.add_argument('--lr_decay_epochs', type=str, default='700,800,900',
                        help='where to decay lr, can be a list')
    parser.add_argument('--lr_decay_rate', type=float, default=0.1,
                        help='decay rate for learning rate')
    parser.add_argument('--weight_decay', type=float, default=1e-4,
                        help='weight decay')
    parser.add_argument('--momentum', type=float, default=0.9,
                        help='momentum')

    # model dataset
    parser.add_argument('--model', type=str, default='resnet18_large',
                        choices=['two-layer-nn', 'resnet18', 'resnet34', 'resnet50',
                                 'resnet101', 'resnet18_large', 'resnet50_large'])
    parser.add_argument('--head', type=str, default='mlp',
                        choices=['linear', 'mlp', 'fixed', 'identity'])
    parser.add_argument('--kappa', type=float, default=1.0)

    parser.add_argument('--dataset', type=str, default='metashift',
                        choices=['boolean', 'cifar10', 'cifar100', 'spur_cifar10', 'waterbirds',
                                 'cmnist', 'celebA', 'metashift', 'path'], help='dataset')
    parser.add_argument('--mean', type=str, help='mean of dataset in path in form of str tuple')
    parser.add_argument('--std', type=str, help='std of dataset in path in form of str tuple')
    parser.add_argument('--data_folder', type=str, default=None, help='path to custom dataset')
    parser.add_argument('--size', type=int, default=32, help='parameter for RandomResizedCrop')

    # method
    parser.add_argument('--method', type=str, default='SimSiam',
                        choices=['SupCon', 'SimCLR', 'SimSiam', 'BarlowTwins', 'DirectDLR'],
                        help='choose method')

    # temperature
    parser.add_argument('--temp', type=float, default=0.5,
                        help='temperature for loss function')

    # other setting
    parser.add_argument('--cosine', action='store_true',
                        help='using cosine annealing')
    parser.add_argument('--syncBN', action='store_true',
                        help='using synchronized batch normalization')
    parser.add_argument('--warm', action='store_true',
                        help='warm-up for large batch training')
    parser.add_argument('--trial', type=str, default='0',
                        help='id for recording multiple runs')
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--use_wandb', action='store_true')
    parser.add_argument('--wandb_name', default='spur-ssl-project')
    parser.add_argument('--entity', default='spur-ssl')
    parser.add_argument('--energy_threshold', type=float, default=0.9)
    parser.add_argument('--rank_threshold', type=float, default=0.1)
    parser.add_argument('--augmented_features', action='store_true')
    parser.add_argument('--decor_reg', type=float, default=0.0)
    parser.add_argument('--decor_loss', action='store_true')
    parser.add_argument('--ent_reg', type=float, default=0.0)
    parser.add_argument('--ent_loss', action='store_true')
    parser.add_argument('--spec_reg', type=float, default=0.0)  # {0.001, ..., 0.1}
    parser.add_argument('--spec_loss', action='store_true')
    parser.add_argument('--train_set_linear_layer', type=str, default='ds_train',
                        choices=['val', 'train', 'balanced_train', 'ds_train', 'us_train'])
    parser.add_argument('--plot_path', type=str, default='test/eigenvalues_projections',
                        help='path to save the plots')
    parser.add_argument('--spur_str', type=float, default=0.95)
    parser.add_argument('--tracker_dir', type=str, default='simclr_rep_logs')
    parser.add_argument('--pruning_amount', type=float, default=0.0)

    opt = parser.parse_args()

    # A custom 'path' dataset needs its location and normalisation statistics.
    # Use parser.error rather than assert: asserts are stripped under `python -O`.
    if opt.dataset == 'path' and None in (opt.data_folder, opt.mean, opt.std):
        parser.error('--dataset path requires --data_folder, --mean and --std')

    # set the path according to the environment
    if opt.data_folder is None:
        opt.data_folder = './datasets/'
    opt.model_path = './save/{}/{}_models'.format(opt.method, opt.dataset)

    # '700,800,900' -> [700, 800, 900]
    opt.lr_decay_epochs = [int(it) for it in opt.lr_decay_epochs.split(',')]

    opt.model_name = '{}_{}_temp_{}_trial_{}'.\
        format(opt.method, opt.dataset, opt.temp, opt.trial)

    if opt.cosine:
        opt.model_name = '{}_cosine'.format(opt.model_name)

    opt.model_name = '{}_{}_{}_{}_{}_{}'.format(
        opt.model_name, opt.seed, opt.decor_reg, opt.ent_reg, opt.spec_reg,
        opt.augmented_features)

    # warm-up for large-batch training
    if opt.batch_size > 256:
        opt.warm = True
    if opt.warm:
        opt.model_name = '{}_warm'.format(opt.model_name)
        opt.warmup_from = 0.01
        opt.warm_epochs = 10
        if opt.cosine:
            # Warm up to the value the cosine schedule takes at epoch `warm_epochs`.
            eta_min = opt.learning_rate * (opt.lr_decay_rate ** 3)
            opt.warmup_to = eta_min + (opt.learning_rate - eta_min) * (
                    1 + math.cos(math.pi * opt.warm_epochs / opt.epochs)) / 2
        else:
            opt.warmup_to = opt.learning_rate

    opt.save_folder = os.path.join(opt.model_path, opt.model_name)
    # exist_ok avoids the check-then-create race when several runs start together.
    os.makedirs(opt.save_folder, exist_ok=True)

    return opt


def set_loader(opt):
    """Build the two-view training data loader for ``opt.dataset``.

    Every image dataset is wrapped in ``TwoCropTransform`` so that each sample
    yields two independently augmented views. The SimSiam pipeline additionally
    applies a random Gaussian blur.

    Side effect: ``opt.size`` is overwritten with the dataset's crop size for
    every dataset except ``'boolean'``.

    Args:
        opt: options namespace as produced by ``parse_option``. Reads
            ``dataset``, ``method``, ``size``, ``mean``, ``std``,
            ``data_folder``, ``batch_size`` and ``num_workers``.

    Returns:
        The training data loader (a ``torch.utils.data.DataLoader`` for
        boolean/cifar/path datasets, otherwise whatever
        ``get_ssl_train_loader`` returns).

    Raises:
        ValueError: if ``opt.dataset`` is not supported, or if ``--mean`` /
            ``--std`` are not valid Python literals for the ``'path'`` dataset.
    """
    # Function-scope import: only needed to parse --mean/--std for 'path'.
    import ast

    # dataset -> (crop size, or None to keep opt.size; mean; std; group-by fields)
    presets = {
        'boolean': (None, 0.0, 1.0, None),
        'cifar10': (32, (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010), None),
        'cifar100': (32, (0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761), None),
        'spur_cifar10': (32, (0.485, 0.456, 0.406), (0.229, 0.224, 0.225), ['background', 'y']),
        'waterbirds': (224, (0.485, 0.456, 0.406), (0.229, 0.224, 0.225), ['background', 'y']),
        'cmnist': (32, (0.1307, 0.1307, 0.), (0.3081, 0.3081, 0.3081), ['background', 'y']),
        'metashift': (224, (0.485, 0.456, 0.406), (0.229, 0.224, 0.225), ['env', 'y']),
        'celebA': (224, (0.485, 0.456, 0.406), (0.229, 0.224, 0.225), ['male', 'y']),
    }

    if opt.dataset == 'path':
        opt.size = 224
        # SECURITY: these strings come straight from the command line. They were
        # previously passed to eval(); literal_eval only accepts Python literals
        # such as "(0.5, 0.5, 0.5)" and cannot execute code.
        mean = ast.literal_eval(opt.mean)
        std = ast.literal_eval(opt.std)
        groupby_fields = None
    elif opt.dataset in presets:
        size, mean, std, groupby_fields = presets[opt.dataset]
        if size is not None:
            opt.size = size
    else:
        raise ValueError('dataset not supported: {}'.format(opt.dataset))
    normalize = transforms.Normalize(mean=mean, std=std)

    if opt.dataset != 'boolean':
        augmentations = [
            transforms.RandomResizedCrop(size=opt.size, scale=(0.2, 1.0)),
            transforms.RandomHorizontalFlip(),
            transforms.RandomApply([
                transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)
            ], p=0.8),
            transforms.RandomGrayscale(p=0.2),
        ]
        if opt.method == 'SimSiam':
            # Gaussian blur with 50% probability; `// 20 * 2 + 1` keeps the
            # kernel size odd (3 for size 32, 23 for size 224).
            augmentations.append(transforms.RandomApply([
                transforms.GaussianBlur(kernel_size=opt.size // 20 * 2 + 1, sigma=(0.1, 2.0))
            ], p=0.5))
        augmentations += [transforms.ToTensor(), normalize]
        train_transform = transforms.Compose(augmentations)

    # Datasets served by a plain DataLoader; everything else goes through WILDS.
    plain_datasets = ('boolean', 'cifar10', 'cifar100', 'path')

    if opt.dataset == 'boolean':
        core_function = generate_parity_func([0, 1, 2])
        spurious_function = generate_parity_func([0, 1])

        train_base_dataset = SpuriousBooleanFinite(core_len=10, spurious_len=10,
                                  core_func=core_function, spurious_func=spurious_function,
                                  c=0.9, sample_num=60000, batch_size=opt.batch_size,
                                  sampling_method="pure",
                                  device="cpu")

        train_dataset = ContrastiveBooleanDataset(train_base_dataset, crop_ratio=0.8)

    elif opt.dataset == 'cifar10':
        train_dataset = datasets.CIFAR10(root=opt.data_folder,
                                         transform=TwoCropTransform(train_transform),
                                         download=True)
    elif opt.dataset == 'cifar100':
        train_dataset = datasets.CIFAR100(root=opt.data_folder,
                                          transform=TwoCropTransform(train_transform),
                                          download=True)
    elif opt.dataset == 'path':
        train_dataset = datasets.ImageFolder(root=opt.data_folder,
                                            transform=TwoCropTransform(train_transform))
    else:
        # NOTE(review): the WILDS root is hard-coded and ignores opt.data_folder.
        full_dataset = wilds.get_dataset(
            dataset=opt.dataset,
            root_dir='./datasets',
            split_scheme='official')
        # spur_str=opt.spur_str  (disabled)
        train_grouper = CombinatorialGrouper(dataset=full_dataset, groupby_fields=groupby_fields)
        train_dataset = full_dataset.get_subset(
            'train', frac=1., transform=TwoCropTransform(train_transform))

    if opt.dataset in plain_datasets:
        train_loader = torch.utils.data.DataLoader(
            train_dataset, batch_size=opt.batch_size, shuffle=True,
            drop_last=False, num_workers=opt.num_workers, pin_memory=True, sampler=None)
    else:
        train_loader = get_ssl_train_loader(
            loader='standard',
            dataset=train_dataset,
            batch_size=opt.batch_size,
            uniform_over_groups=False,
            grouper=train_grouper,
            drop_last=False,
            num_workers=opt.num_workers,
            pin_memory=True)

    return train_loader


def set_model(opt):
    """Instantiate the SSL model and its loss for ``opt.method``.

    Args:
        opt: options namespace. Reads ``method``, ``model``, ``head``,
            ``kappa``, ``temp``, ``syncBN`` and the regulariser settings
            (``augmented_features``, ``decor_*``, ``ent_*``, ``spec_*``).

    Returns:
        tuple: ``(model, criterion)``, both moved to CUDA when it is available.
        With more than one GPU, ``model.encoder`` is wrapped in
        ``torch.nn.DataParallel``.

    Raises:
        ValueError: if ``opt.method`` has no model/loss defined here (this
            includes ``'SupCon'``, which the CLI accepts).
        ImportError: if ``--syncBN`` is requested but apex is not installed.
    """
    # Regulariser arguments shared by the SimSiam and SimCLR losses.
    reg_kwargs = dict(
        augmented_features=opt.augmented_features,
        decor_reg=opt.decor_reg, decor_loss=opt.decor_loss,
        ent_reg=opt.ent_reg, ent_loss=opt.ent_loss,
        spec_reg=opt.spec_reg, spec_loss=opt.spec_loss)

    if opt.method == 'SimSiam':
        model = SimSiam(name=opt.model)
        criterion = SimSiamLoss(**reg_kwargs)

    elif opt.method == 'SimCLR':
        model = SupConResNet(name=opt.model, head=opt.head, k=opt.kappa)
        criterion = SupConLoss(temperature=opt.temp, **reg_kwargs)

    elif opt.method == 'BarlowTwins':
        model = BarlowTwinsModel(name=opt.model)
        criterion = BarlowTwinsLoss(lambda_param=5e-3)  # 1e-2

    elif opt.method == 'DirectDLR':
        model = DirectDLRModel(name=opt.model)
        criterion = DirectDLRLoss(temperature=opt.temp)

    else:
        raise ValueError('contrastive method not supported: {}'.
                             format(opt.method))

    # enable synchronized Batch Normalization
    if opt.syncBN:
        try:
            convert_syncbn = apex.parallel.convert_syncbn_model
        except NameError as err:
            # The module-level apex import is optional and fails silently.
            raise ImportError(
                '--syncBN requires NVIDIA apex, which could not be imported') from err
        model = convert_syncbn(model)

    if torch.cuda.is_available():
        if torch.cuda.device_count() > 1:
            model.encoder = torch.nn.DataParallel(model.encoder)
        model = model.cuda()
        criterion = criterion.cuda()
        # Autotuning picks algorithms non-deterministically, so do not override
        # a caller that has explicitly requested deterministic cuDNN.
        if not cudnn.deterministic:
            cudnn.benchmark = True

    return model, criterion


def train(train_loader, model, criterion, optimizer, epoch, opt, tracker, global_step):
    """Run one epoch of self-supervised training.

    Args:
        train_loader: iterable of batches with ``len()``; ``batch[0]`` must
            hold the two augmented views as ``(view1, view2)``.
        model: SSL model matching ``opt.method``.
        criterion: loss returning ``(loss, decor_loss, entropy_loss,
            spectral_flattening_loss)``.
        optimizer: optimizer stepped once per batch.
        epoch: current epoch number, used for warm-up and logging.
        opt: options namespace. Reads ``method`` and ``print_freq`` here and is
            forwarded to ``warmup_learning_rate``.
        tracker: currently unused by this function.
        global_step: currently unused; returned unchanged.

    Returns:
        tuple: ``(loss, decor_loss, entropy_loss, spectral_flattening_loss,
        global_step)``, where the four losses are the ``AverageMeter.avg``
        values accumulated over the epoch.

    Raises:
        ValueError: if ``opt.method`` is not SimCLR, SimSiam, BarlowTwins or
            DirectDLR.
    """
    model.train()

    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    decor_losses = AverageMeter()
    entropy_losses = AverageMeter()
    spectral_flattening_losses = AverageMeter()

    # Loop-invariant values.
    use_cuda = torch.cuda.is_available()
    num_batches = len(train_loader)

    end = time.time()

    for idx, data in enumerate(train_loader):
        data_time.update(time.time() - end)

        # Bind the two views to locals instead of writing back into the batch.
        view1, view2 = data[0][0], data[0][1]
        if use_cuda:
            view1 = view1.cuda(non_blocking=True)
            view2 = view2.cuda(non_blocking=True)

        bsz = view1.size(0)

        # warm-up learning rate
        warmup_learning_rate(opt, epoch, idx, num_batches, optimizer)

        # compute loss
        if opt.method == 'SimCLR':
            images = torch.cat([view1, view2], dim=0)
            features = model(images)
            f1, f2 = torch.split(features, [bsz, bsz], dim=0)
            features = torch.cat([f1.unsqueeze(1), f2.unsqueeze(1)], dim=1)
            loss, decor_loss, entropy_loss, spectral_flattening_loss = criterion(features)

        elif opt.method == 'SimSiam':
            h1, h2, z1, z2, p1, p2 = model(view1, view2)
            loss, decor_loss, entropy_loss, spectral_flattening_loss = criterion(h1, p1, p2, z1, z2)

        elif opt.method == 'BarlowTwins':
            images = torch.cat([view1, view2], dim=0)   # [2*B, C, H, W]
            features = model(images)                    # [2*B, d]
            # The batch is laid out as [view1; view2], so rows i and B+i belong
            # to the same sample. A plain .view(B, 2, -1) would instead pair
            # adjacent rows (two different images). Split and stack the halves,
            # as the SimCLR branch does.
            # NOTE(review): assumes BarlowTwinsLoss expects features[:, 0] and
            # features[:, 1] to be the two views of one sample - confirm.
            f1, f2 = torch.split(features, [bsz, bsz], dim=0)
            features = torch.cat([f1.unsqueeze(1), f2.unsqueeze(1)], dim=1)  # [B, 2, d]
            loss, decor_loss, entropy_loss, spectral_flattening_loss = criterion(features)

        elif opt.method == 'DirectDLR':
            h1_sub = model(view1, return_subvector_only=True)
            h2_sub = model(view2, return_subvector_only=True)
            loss, decor_loss, entropy_loss, spectral_flattening_loss = criterion(h1_sub, h2_sub)

        else:
            raise ValueError('contrastive method not supported: {}'.
                             format(opt.method))

        # update metric
        losses.update(loss.item(), bsz)
        decor_losses.update(decor_loss.item(), bsz)
        entropy_losses.update(entropy_loss.item(), bsz)
        spectral_flattening_losses.update(spectral_flattening_loss.item(), bsz)

        # SGD
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        # print info
        if (idx + 1) % opt.print_freq == 0:
            print('Train: [{0}][{1}/{2}]\t'
                  'BT {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'DT {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'loss {loss.val:.3f} ({loss.avg:.3f})'.format(
                   epoch, idx + 1, num_batches, batch_time=batch_time,
                   data_time=data_time, loss=losses))
            sys.stdout.flush()

    return losses.avg, decor_losses.avg, entropy_losses.avg, spectral_flattening_losses.avg, global_step

def get_features(model, data_loader, opt):
    """Extract L2-normalised encoder features for every batch of a loader.

    The model is switched to eval mode and is NOT switched back; callers that
    continue training must call ``model.train()`` again.

    Args:
        model: object exposing ``eval()`` and an ``encoder`` callable.
        data_loader: iterable of batches; ``batch[0]`` must hold the two views
            as ``(view1, view2)``.
        opt: options namespace. When ``opt.augmented_features`` is truthy both
            views are encoded (view 1 rows followed by view 2 rows within each
            batch); otherwise only view 1 is used.

    Returns:
        torch.Tensor: CPU tensor of all encoder outputs concatenated along
        dim 0 and normalised along dim 1.
    """
    model.eval()

    use_cuda = torch.cuda.is_available()
    collected = []

    with torch.no_grad():  # no gradients needed for feature extraction
        for data in data_loader:
            views = data[0]
            if opt.augmented_features:
                images = torch.cat([views[0], views[1]], dim=0)
            else:
                images = views[0]
            if use_cuda:
                images = images.cuda(non_blocking=True)

            # Move each batch to the CPU straight away so the accumulated
            # features do not sit on the GPU.
            collected.append(model.encoder(images).cpu())

    # concatenate all batches, then normalise each feature vector
    features = F.normalize(torch.cat(collected, dim=0), dim=1)

    print("Extracted features shape:", features.shape)  # (N, feature_dim)
    return features


def main():
    """Pre-train an SSL model and periodically run a linear probe on it.

    Steps: parse options, seed all RNGs, optionally start a wandb run, build
    the loader/model/optimizer, then for each epoch:

    * train for one epoch;
    * every ``opt.print_freq`` epochs, compute entropy / effective rank /
      energy-based rank of the training features and (with ``--use_wandb``)
      log them together with the SSL losses and learning rate;
    * every ``opt.save_freq`` epochs, save a checkpoint and run
      ``main_linear.main`` on it.

    A final checkpoint ``last.pth`` is written after the last epoch.

    Raises:
        ValueError: at the first checkpoint epoch if ``opt.dataset`` has no
            known class count for the linear probe (e.g. ``'path'``).
    """
    global_step = 0
    opt = parse_option()
    print(opt)

    # Namespace viewed as a dict; passed to wandb as the run config.
    hyperparameters = vars(opt)

    # The run name is built from these option values, in this order.
    selected_keys = ['method', 'dataset', 'spec_loss', 'spec_reg', 'trial', 'seed']
    run_name = "_".join(
        str(hyperparameters[key]) for key in selected_keys if key in hyperparameters)

    # setup the seed
    torch.manual_seed(opt.seed)
    np.random.seed(opt.seed)
    random.seed(opt.seed)
    torch.cuda.manual_seed(opt.seed)

    # enable deterministic algorithms
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # setup wandb
    if opt.use_wandb:
        wandb.init(
            project=opt.wandb_name,
            name=run_name,
            config=hyperparameters,
            entity=opt.entity
        )

    tracker = RepresentationTracker(augmented_features=opt.augmented_features,
                                    save_dir=opt.tracker_dir, track_projector=True)

    # build data loader
    train_loader = set_loader(opt)
    # build model and criterion
    model, criterion = set_model(opt)
    # set_model() may switch cudnn.benchmark on; re-assert the deterministic
    # configuration requested above.
    torch.backends.cudnn.benchmark = False

    # build optimizer
    optimizer = set_optimizer(opt, model)

    # Class count per dataset for the linear probe.
    n_cls_by_dataset = {
        'cifar10': 10, 'spur_cifar10': 10, 'cifar100': 100,
        'waterbirds': 2, 'cmnist': 2, 'metashift': 2, 'celebA': 2, 'boolean': 2,
    }

    # training routine
    for epoch in range(1, opt.epochs + 1):
        adjust_learning_rate(opt, optimizer, epoch)

        # train for one epoch
        time1 = time.time()
        train_loss, decor_loss, entropy_loss, spectral_flattening_loss, global_step = train(
            train_loader, model, criterion, optimizer, epoch, opt, tracker, global_step)
        time2 = time.time()
        print('epoch {}, total time {:.2f}'.format(epoch, time2 - time1))

        if epoch % opt.print_freq == 0:

            # entropy and rank statistics of the training features
            train_features = get_features(model, train_loader, opt)
            entropy, effective_rank, energy_based_rank = get_entropy_energy_based_rank(train_features, opt)
            print('epoch {},'.format(epoch), 'entropy {:.2f},'.format(entropy),
                  'effective rank {},'.format(effective_rank), 'and energy-based rank {}'.format(energy_based_rank))

            # Only log when a wandb run was started above; calling wandb.log
            # without wandb.init raises and would abort training.
            if opt.use_wandb:
                wandb_dict_1 = {'Entropy': entropy, 'Effective rank': effective_rank,
                                'Energy-based rank': energy_based_rank}
                wandb.log(wandb_dict_1, step=epoch)

                wandb_dict = {'SSL train loss': train_loss,
                              'SSL decor loss': decor_loss,
                              'SSL entropy loss': entropy_loss,
                              'SSL spectral flattening loss': spectral_flattening_loss,
                              'SSL learning rate': optimizer.param_groups[0]['lr']}
                wandb.log(wandb_dict, step=epoch)

        if epoch % opt.save_freq == 0:

            # save the model
            save_file = os.path.join(
                opt.save_folder, 'ckpt_epoch_{epoch}.pth'.format(epoch=epoch))
            save_model(model, optimizer, opt, epoch, save_file)

            # Arguments for the linear probe, built directly rather than via
            # its CLI parser.
            # NOTE(review): num_workers and data_folder are hard-coded and do
            # not follow opt.num_workers / opt.data_folder.
            linear_args = argparse.Namespace(
                print_freq=10, save_freq=50, batch_size=opt.batch_size, num_workers=32, epochs=100,
                learning_rate=1.0, lr_decay_epochs=[60, 75, 90], lr_decay_rate=0.2,
                weight_decay=0, momentum=0.9, method=opt.method, model=opt.model, head=opt.head,
                kappa=opt.kappa, dataset=opt.dataset, cosine=opt.cosine, warm=False, ckpt=save_file,
                use_wandb=opt.use_wandb, train_set_linear_layer=opt.train_set_linear_layer,
                augmented_features=opt.augmented_features, plot_path=opt.plot_path,
                rank_threshold=opt.rank_threshold, energy_threshold=opt.energy_threshold,
                spur_str=opt.spur_str, seed=opt.seed)
            linear_args.data_folder = './datasets/'

            if opt.dataset not in n_cls_by_dataset:
                raise ValueError('dataset not supported: {}'.format(opt.dataset))
            linear_args.n_cls = n_cls_by_dataset[opt.dataset]

            # train and validation of the linear model
            linear_eval(linear_args, supcon_epoch=epoch)

    # save the last model
    save_file = os.path.join(
        opt.save_folder, 'last.pth')
    save_model(model, optimizer, opt, opt.epochs, save_file)


# Run the pre-training pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()
